#############
#############
#############
## Stefan Müller and Sven-Oliver Proksch:
## Nostalgia in European Party Politics:
## A Text-Based Measurement Approach
## British Journal of Political Science
##
## Python script to fine-tune and cross-validate
## nostalgia DistilBERT classifier
#############
#############
#############

## If you would like to apply the fine-tuned
## classifier to your own data, check the following file:
## 07_tutorial_classify_nostalgia_distilbert.ipynb

## install packages
# !pip install datasets
# !pip install transformers
# !pip install scikit-learn
# !pip3 install torch torchvision
# !pip install psutil

## load packages and functions
from datasets import load_dataset, load_metric
from transformers import AutoTokenizer, TrainingArguments, AutoModelForSequenceClassification, Trainer
import numpy as np

import torch
import pandas as pd
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
import os

# os.environ["TOKENIZERS_PARALLELISM"] = "true"

## 1. Pre-processing

## determine category for training and prediction
# name of the coded column the classifier learns and predicts (binary: num_labels=2 below)
label_train = "nostalgic"

# read the hand-coded training CSV; load_dataset puts a single CSV into a "train" split
dataset_dict = load_dataset("csv", data_files="data_coded_train.csv")
dataset = dataset_dict["train"]

print(dataset)  # expected: 960 rows

# tokenizer that matches the pre-trained checkpoint fine-tuned below
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')

## preprocessing functions
def transform_labels(label):
    """Return the training target for one dataset row.

    Despite its name, ``label`` is a whole example (one row, since this is
    mapped without ``batched=True``). The value of its ``label_train`` column
    is returned under the key ``'labels'``.
    """
    return {'labels': label[label_train]}

## tokenization function
## (padding and truncation give every text the same fixed length: short texts are padded, long texts are cut off)
def tokenize_data(example):
    """Tokenize the ``'text'`` field of a batch of rows.

    Used with ``dataset.map(..., batched=True)``, so ``example['text']`` is
    presumably a list of strings. Sequences are padded to the tokenizer's
    maximum length and longer ones are truncated, so all inputs share one length.
    """
    texts = example['text']
    return tokenizer(texts, truncation=True, padding='max_length')

## tokenize text
dataset = dataset.map(tokenize_data, batched=True)

# Drop the raw text and metadata columns and expose the target as "labels",
# the keyword the model uses to compute the training loss. The label column is
# referenced through label_train (not a hard-coded "nostalgic") so this step
# stays consistent with transform_labels if label_train is changed.
remove_columns = ["text", "countryname", "translation_inaccurate", label_train]

dataset = dataset.map(transform_labels, remove_columns=remove_columns)

## split data into training (80%) and evaluation (20%) sets;
## shuffle=False keeps the file's row order, so the first 768 of the 960
## observations are used for training and the last 192 for evaluation
dataset = dataset.train_test_split(test_size=0.2, shuffle=False)

train_dataset = dataset['train']
print(train_dataset)  # expected: 768 observations

eval_dataset = dataset['test']
print(eval_dataset)  # expected: 192 observations

## fine-tune for 3 epochs (the TrainingArguments default) and report the
## F1 score on the evaluation set after each epoch
## https://huggingface.co/docs/transformers/training
# F1 metric object; compute_metrics() calls metric.compute() at each evaluation.
# NOTE(review): load_metric is deprecated in newer `datasets` releases (replaced by
# `evaluate.load`) — confirm the datasets version this script is pinned to.
metric = load_metric("f1")

def compute_metrics(eval_pred):
    """Score one evaluation pass for the Trainer.

    ``eval_pred`` unpacks into (model logits, gold labels). Each row of logits
    is reduced to its highest-scoring class, and the global ``metric`` (F1)
    compares those predictions to the gold labels. Returns the metric's dict.
    """
    scores, gold = eval_pred
    predicted_class = np.argmax(scores, axis=-1)
    return metric.compute(predictions=predicted_class, references=gold)


# checkpoints/logs go to ./test_trainer; evaluate on the eval set after every epoch
# NOTE(review): `evaluation_strategy` was renamed `eval_strategy` in newer
# transformers releases — confirm against the pinned version.
training_args = TrainingArguments(
    output_dir="test_trainer",
    evaluation_strategy="epoch",
    per_device_eval_batch_size=128,
)

## pre-trained checkpoint: https://huggingface.co/distilbert-base-uncased
## loaded with a 2-label classification head (num_labels=2)
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)


## 2. Train downstream task

## build trainer
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,  # reports F1 on eval_dataset after each epoch
    train_dataset=train_dataset,      # weights are updated on the training split
    eval_dataset=eval_dataset,        # held-out split, used for per-epoch evaluation only
)


## train model (takes time!)
trainer.train()

## save the fine-tuned model so it can be reused
model_dir = "distilbert_nostalgic"
trainer.save_model(model_dir)

## reload the fine-tuned model from the same directory it was just saved to,
## so the pipeline demo below uses the model trained above
## NOTE(review): to reuse a previously published copy instead of retraining,
## point model_dir at that directory (e.g. distilbert_models/distilbert_nostalgic).
model = AutoModelForSequenceClassification.from_pretrained(model_dir, num_labels=2)


## 3. Apply to unseen data

## fill in path to unlabelled data (created in 01c_assess_handcoding_round_03.R)
unlab_data_path = "data_coded_test.csv"

unlab_dataset = load_dataset("csv", data_files=unlab_data_path)
print(unlab_dataset)  # expected: 240 documents

# tokenize exactly as the training data was
unlab_dataset = unlab_dataset.map(tokenize_data, batched=True)

# run the trained model over the CSV's single "train" split; the model returns
# logits, and the highest-scoring class per document is the predicted label
preds = trainer.predict(unlab_dataset["train"])
pred_labels = np.argmax(preds.predictions, axis=1)

# attach predictions to the original rows (assumes load_dataset and read_csv
# keep the file's row order) and write them out; to_csv also writes the index
output = pd.read_csv(unlab_data_path)
output[f"{label_train}_bert"] = pred_labels
output.to_csv("data_classified_bert_test.csv")


from transformers import TextClassificationPipeline

text1 = ["Let's return to the good old days and consider our cultural heritage."]

# sanity check: classify an example sentence with the reloaded model.
# The result is printed because a bare expression is discarded when this file
# runs as a script. Labels presumably use generic names (LABEL_0 / LABEL_1,
# with LABEL_1 = nostalgic) — confirm against the model config.
pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer)
print(pipe(text1))

